/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FUSION_ENGINE_OPTIMIZER_GRAPH_OPTIMIZER_UB_FUSION_BUFFER_FUSION_PASS_RUNNER_H_
#define FUSION_ENGINE_OPTIMIZER_GRAPH_OPTIMIZER_UB_FUSION_BUFFER_FUSION_PASS_RUNNER_H_

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "common/fe_log.h"
#include "common/fe_utils.h"
#include "common/math_util.h"
#include "common/scope_allocator.h"
#include "common/util/constants.h"
#include "common/util/op_info_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/node.h"
#include "graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h"
#include "reachability_map.h"
#include "register/graph_optimizer/graph_fusion/graph_pass.h"

namespace fe {
// Shared-ownership pointer alias for ScopeAllocator.
using ScopeAllocatorPtr = std::shared_ptr<ScopeAllocator>;
// Exclusive-ownership pointer alias for BufferFusionPassBase.
using BufferFusionPassBasePtr = std::unique_ptr<BufferFusionPassBase>;
/**
 * @ingroup fe
 * @brief Graph pass that matches buffer fusion patterns against a compute graph and fuses the matched nodes.
 *
 * All standard-library types in this header are spelled with an explicit std:: qualifier so that the
 * declarations do not depend on a using-declaration leaking in from another header.
 * The class holds a std::unique_ptr member, so it is not copyable.
 */
class BufferFusionPassRunner : public GraphPass {
 public:
  /**
   * @ingroup fe
   * @brief Construct a runner
   * @param [in] name: name of the pass
   * @param [in] create_fn: function returning a BufferFusionPassBase pointer
   *             (presumably the factory for the wrapped pass; ownership handling is in the definition - confirm)
   * @param [in] scope_allocat_ptr: shared scope allocator
   * @param [in] reachability: shared reachability map
   */
  BufferFusionPassRunner(const std::string &name, BufferFusionPassBase *(*create_fn)(),
                         const ScopeAllocatorPtr &scope_allocat_ptr, std::shared_ptr<ReachabilityMap> &reachability);
  virtual ~BufferFusionPassRunner();

  /**
   * @ingroup fe
   * @brief Distinguish pattern and do fusion
   * @param [in] graph: graph to run the pass on
   * @return Status: result of the pass
   */
  Status Run(ge::ComputeGraph &graph) override;

 private:
  /*
   * @brief: match one pattern, and do fusion for the matched nodes
   * @param [in] graph: graph to match against
   * @param [in] pattern: fusion pattern info
   * @return bool: match current pattern ok or not
   */
  bool RunOnePattern(ge::ComputeGraph &graph, BufferFusionPattern &pattern);

  /*
   * @brief: check whether the node is a TVM (TBE) type op
   * @param [in] node: graph node
   * @return bool: check result
   */
  bool IsTbeOp(ge::NodePtr node);

  /*
   * @brief: check the type of a node
   * @param [in] node: graph node
   * @return bool: check result
   * NOTE(review): an earlier comment described the return value as an OPTYPE enum, but the declared
   * return type is bool - confirm the exact meaning against the definition.
   */
  bool NodeType(ge::NodePtr node);

  /*
   * @brief: check whether the node must be ignored for UB fusion
   * @param [in] node: graph node
   * @return bool: check result
   * NOTE(review): an earlier comment said "check if is Valid op", which reads as the opposite of the
   * function name - confirm which polarity the definition implements.
   */
  bool NeedIgnoreOp(ge::NodePtr node);

  /*
   * @brief: get the op type of a node
   * @param [in] node: graph node
   * @param [out] op_type: type represented by std::string
   * @return bool: get op type ok or not
   */
  bool GetOpAttrType(ge::NodePtr node, std::string &op_type);

  /*
   * @brief: check whether node output size is the same as the candidate desc output size
   * @param [in] node: graph node
   * @param [in] op_desc: candidate pattern desc
   * @param [in] pattern_name: name of the pattern being matched
   * @return bool: check result
   */
  bool SkipDiffSizeDesc(ge::NodePtr node, const BufferFusionOpDesc *op_desc, const std::string &pattern_name);

  /*
   * @brief: check the node against the candidate desc
   *         (judging by the name, a shape-type comparison used to skip mismatching descs - confirm)
   * @param [in] node: graph node
   * @param [in] op_desc: candidate pattern desc
   * @return bool: check result
   */
  bool SkipDiffShapeTypeDesc(ge::NodePtr node, const BufferFusionOpDesc *op_desc);

  /*
   * @brief: get current loop fusion match status
   * @param [in] Is_parallel: graph node is multi branch or single branch
   * @param [in] opdescs: candidate pattern descs (passed by value, i.e. copied per call)
   * @param [in] usage: record whether desc has been matched (passed by value, i.e. copied per call)
   * @return bool: all current loop descs have been matched or not
   */
  bool GetCurrMatchStatus(bool Is_parallel, std::vector<BufferFusionOpDesc *> opdescs,
                          std::map<BufferFusionOpDesc *, bool> usage);

  /*
   * @brief: get pattern fusion match status
   * @param [in] pattern: fusion pattern info
   * @return bool: the pattern has been matched or not
   */
  bool GetPatternMatchStatus(BufferFusionPattern &pattern);

  /*
   * @brief: get the fusion pattern head desc matched by a node
   * @param [in] node: graph node
   * @param [in] pattern_name: name of the pattern being matched
   * @param [in] head_descs: candidate head desc list (passed by value, i.e. copied per call)
   * @return BufferFusionOpDesc*: head desc ptr
   */
  BufferFusionOpDesc *GetMatchedHeadDesc(ge::NodePtr node, const std::string &pattern_name,
                                         std::vector<BufferFusionOpDesc *> head_descs);

  /*
   * @brief: get the current loop desc matched by a node
   * @param [in] node: graph node
   * @param [in] head_desc: valid head desc
   * @param [in] descs: candidate descs (passed by value, i.e. copied per call)
   * @param [in] usage: record whether desc has been matched (passed by value, i.e. copied per call)
   * @param [in/out] matched_output_nodes: nested map of names keyed by std::string, then by int32_t
   *                 (presumably desc name -> branch/output index -> matched node names; confirm)
   * @param [in] pattern_name: name of the pattern being matched
   * @return BufferFusionOpDesc*: matched desc ptr
   */
  BufferFusionOpDesc *GetMatchedNormalDesc(
      ge::NodePtr node, BufferFusionOpDesc *head_desc, std::vector<BufferFusionOpDesc *> descs,
      std::map<BufferFusionOpDesc *, bool> usage,
      std::map<std::string, std::map<int32_t, std::vector<std::string>>> &matched_output_nodes,
      const std::string &pattern_name);

  /*
   * @brief: match a fusion pattern using the given desc/node queues, recording results in mapping
   *         (traversal details are in the definition)
   * @param [in/out] queue_descs: queue of pattern descs
   * @param [in/out] queue_nodes: queue of graph nodes
   * @param [in] pattern: fusion pattern info
   * @param [in/out] mapping: desc-to-nodes mapping
   * @param [in] head_desc: head desc of the pattern
   */
  void MatchFusionPattern(std::vector<BufferFusionOpDesc *> &queue_descs, std::vector<ge::NodePtr> &queue_nodes,
                          BufferFusionPattern &pattern, BufferFusionMapping &mapping, BufferFusionOpDesc *head_desc);

  /*
   * @brief: match the nodes following a given node (judging by the name; see the definition for details)
   * @param [in] node: graph node whose followers are examined
   * @param [in/out] queue_descs: queue of pattern descs
   * @param [in/out] queue_nodes: queue of graph nodes
   * @param [in/out] curr_descs: descs of the current step
   * @param [in] pattern: fusion pattern info
   * @param [in/out] usage_flags: record whether desc has been matched
   * @param [in/out] mapping: desc-to-nodes mapping
   * @param [in] head_desc: head desc of the pattern
   * @param [in/out] matched_output_nodes: nested map of names keyed by std::string, then by int32_t
   */
  void MatchFollowingNodes(ge::NodePtr node, std::vector<BufferFusionOpDesc *> &queue_descs,
                           std::vector<ge::NodePtr> &queue_nodes, std::vector<BufferFusionOpDesc *> &curr_descs,
                           BufferFusionPattern &pattern, std::map<BufferFusionOpDesc *, bool> &usage_flags,
                           BufferFusionMapping &mapping, BufferFusionOpDesc *head_desc,
                           std::map<std::string, std::map<int32_t, std::vector<std::string>>> &matched_output_nodes);

  /*
   * @brief: recover the current mapping and queues from previously saved ones
   *         (judging by the name and parameters; exact policy is in the definition)
   * @param [in/out] saved_queue_descs: saved desc queues
   * @param [in/out] saved_queue_nodes: saved node queues
   * @param [in/out] saved_mappings: saved mappings
   * @param [in/out] curr_queue_descs: current desc queue
   * @param [in/out] curr_queue_nodes: current node queue
   * @param [in/out] curr_mapping: current mapping
   * @param [in] match_error: whether the current match hit an error
   * @param [in/out] longest_mapping: longest mapping recorded so far
   * @param [in/out] longest_num: size associated with longest_mapping
   * @param [in] pattern: fusion pattern info
   */
  void RecoverMappingAndQueue(std::vector<std::vector<BufferFusionOpDesc *>> &saved_queue_descs,
                              std::vector<std::vector<ge::NodePtr>> &saved_queue_nodes,
                              BufferFusionMappings &saved_mappings,
                              std::vector<BufferFusionOpDesc *> &curr_queue_descs,
                              std::vector<ge::NodePtr> &curr_queue_nodes, BufferFusionMapping &curr_mapping,
                              bool match_error, BufferFusionMapping &longest_mapping, size_t &longest_num,
                              BufferFusionPattern &pattern);

  /*
   * @brief: compare the current mapping with the longest one recorded so far
   *         (presumably updating longest_mapping/longest_num when the current one is longer; confirm)
   * @param [in] curr_mapping: current mapping
   * @param [in/out] longest_mapping: longest mapping recorded so far
   * @param [in/out] longest_num: size associated with longest_mapping
   */
  void CompareMappings(BufferFusionMapping &curr_mapping, BufferFusionMapping &longest_mapping, size_t &longest_num);

  /*
   * @brief: loop check between a mapping and a target node
   *         (judging by the name; the exact condition and return polarity are in the definition)
   * @param [in] mapping: desc-to-nodes mapping
   * @param [in] targetnode: node to check
   * @return bool: check result
   */
  bool CheckLoopForward(BufferFusionMapping &mapping, ge::NodePtr &targetnode);

  /*
   * @brief: check whether the desc is an optional output
   * @param [in] desc: pattern desc
   * @return bool: check result
   */
  bool IsOptionalOutput(BufferFusionOpDesc *desc);

  /*
   * @brief: decide whether a node should be skipped when matching a normal (non-head) desc
   * @param [in/out] matched_output_nodes: nested map of names keyed by std::string, then by int32_t
   * @param [in] out_desc: candidate desc
   * @param [in] node_name: name of the node (passed by value)
   * @param [in] node: graph node
   * @param [in] head_desc: head desc of the pattern
   * @param [in] loop_num: loop number of the current match
   * @param [in] pattern_name: name of the pattern being matched
   * @return bool: check result
   */
  bool SkipNodeForNormalDesc(std::map<std::string, std::map<int32_t, std::vector<std::string>>> &matched_output_nodes,
                             BufferFusionOpDesc *out_desc, std::string node_name, ge::NodePtr node,
                             BufferFusionOpDesc *head_desc, int64_t loop_num, const std::string &pattern_name);

  /*
   * @brief: decide whether a node should be skipped before it is matched against a desc
   * @param [in] node: graph node
   * @param [in] curr_node_num: current node count
   * @param [in] curr_desc_num: current desc count
   * @param [in] op_desc: candidate desc
   * @param [in] head_desc: head desc of the pattern
   * @param [in] get_output_result: result of a preceding output query
   * @param [in] pattern_name: name of the pattern being matched
   * @return bool: check result
   */
  bool SkipNodeBeforeMatch(const ge::NodePtr &node, size_t curr_node_num, size_t curr_desc_num,
                           BufferFusionOpDesc *op_desc, BufferFusionOpDesc *head_desc, bool get_output_result,
                           const std::string &pattern_name);

  /*
   * @brief: save the queues and mapping before a match is attempted
   *         (judging by the name and parameters; the counterpart is RecoverMappingAndQueue)
   * @param [in/out] curr_descs: descs of the current step
   * @param [in] node: graph node
   * @param [in] op_desc: candidate desc
   * @param [in/out] queue_descs: queue of pattern descs
   * @param [in/out] queue_nodes: queue of graph nodes
   * @param [in/out] mapping: desc-to-nodes mapping
   * @param [in/out] saved_queue_descs: saved desc queues
   * @param [in/out] saved_queue_nodes: saved node queues
   * @param [in/out] saved_mappings: saved mappings
   * @param [in/out] saved_count: counter of saved states
   */
  void SaveQueueBeforeMatch(std::vector<BufferFusionOpDesc *> &curr_descs, ge::NodePtr node,
                            BufferFusionOpDesc *op_desc, std::vector<BufferFusionOpDesc *> &queue_descs,
                            std::vector<ge::NodePtr> &queue_nodes, BufferFusionMapping &mapping,
                            std::vector<std::vector<BufferFusionOpDesc *>> &saved_queue_descs,
                            std::vector<std::vector<ge::NodePtr>> &saved_queue_nodes,
                            BufferFusionMappings &saved_mappings, uint32_t &saved_count);

  /*
   * @brief: collect the fusion scopes that already exist in the graph
   * @param [in] graph: graph to inspect
   * @param [out] fusion_scopes: nodes grouped by an int64_t key (presumably the scope id; confirm)
   */
  void GetExistingFusionScopes(ge::ComputeGraph &graph, std::map<int64_t, std::vector<ge::NodePtr>> &fusion_scopes);

  /*
   * @brief: check whether graph node is matched with pattern desc
   * @param [in] node: graph node
   * @param [in] op_desc: candidate pattern desc
   * @return bool: check result
   */
  bool IsOpTypeExist(const ge::NodePtr node, const BufferFusionOpDesc *op_desc);

  /*
   * @brief: check a desc type list for the "any" op type (judging by the name; confirm the exact rule)
   * @param [in] types: type list of a desc
   * @return bool: check result
   */
  bool IsOpTypeAny(const std::vector<std::string> &types);

  /*
   * @brief: check a desc type list for the output-node type (judging by the name; confirm the exact rule)
   * @param [in] types: type list of a desc
   * @return bool: check result
   */
  bool IsOutputNode(const std::vector<std::string> &types);

  /*
   * @brief: attribute check on a mapping (which attributes are compared is in the definition)
   * @param [in] mapping: desc-to-nodes mapping
   * @return bool: check result
   */
  bool CheckAttrMatch(BufferFusionMapping &mapping);

  /*
   * @brief: set scope id and pass name on the fused nodes
   * @param [in] fusion_nodes: nodes to be marked
   * @param [in] pass_name: name of the pass
   * @param [in] pattern_name: name of the matched pattern
   */
  void SetScopeIdAndPassName(const std::vector<ge::NodePtr> &fusion_nodes, const std::string &pass_name,
                             const std::string &pattern_name);

  /*
   * @brief: match a pattern starting from a head node
   * @param [in] node_g: graph node used as the head
   * @param [in] pattern: fusion pattern info
   * @param [in/out] mapping: desc-to-nodes mapping
   * @return Status: match result
   */
  Status MatchFromHead(const ge::NodePtr &node_g, BufferFusionPattern &pattern,
                       BufferFusionMapping &mapping);

  /*
   * @brief: initialise the current repeat state of the given descs (judging by the name; confirm)
   * @param [in] ops: pattern descs
   */
  void InitRepeatCurr(const std::vector<BufferFusionOpDesc *> &ops);

  /*
   * @brief: cube/vector split check on the nodes to be fused (exact condition is in the definition)
   * @param [in] fusion_nodes: nodes to be fused
   * @return bool: check result
   */
  bool CheckCubeVectorSplit(std::vector<ge::NodePtr> &fusion_nodes);

  // Per-object constant with value 2; its use is in the definitions (not visible in this header).
  const int TBE_MATCH_LOOP_NUM = 2;

  // Shared scope allocator (presumably the one passed to the constructor - confirm).
  ScopeAllocatorPtr scope_allocator_ptr_;
  // Non-owning-looking raw pointers to patterns; who owns them is not visible here - confirm.
  std::vector<BufferFusionPattern *> patterns_;
  // Exclusively owned fusion pass object.
  BufferFusionPassBasePtr buffer_fusion_pass_base_ptr_;
  // Shared reachability map (presumably the one passed to the constructor - confirm).
  std::shared_ptr<ReachabilityMap> reachability_;
  // Set of op type names (judging by the name, the cube op types - confirm).
  std::set<std::string> cube_op_type_;
};

}  // namespace fe

#endif  // FUSION_ENGINE_OPTIMIZER_GRAPH_OPTIMIZER_UB_FUSION_BUFFER_FUSION_PASS_RUNNER_H_
